/*
*  This file is part of ygg-brute
*  Copyright (c) 2020 ygg-brute authors
*  See LICENSE for licensing information
*/

#pragma once

#include <cuda.h>
#include <cuda_runtime.h>

#include <cstddef>
#include <exception>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

// Evaluates the CUDA runtime call F and checks its cudaError_t result.
// On failure it normally throws std::runtime_error carrying the stringified
// call text plus cudaGetErrorString(); however, if an exception is already
// in flight (std::uncaught_exceptions() != 0) it only logs to stderr, so a
// failure during stack unwinding does not escalate to std::terminate.
// Wrapped in do/while(0) so the macro behaves as a single statement.
#define CUDA_CHECK(F)                                                  \
    do                                                                 \
    {                                                                  \
        auto ret = F;                                                  \
        if (ret != cudaSuccess)                                        \
        {                                                              \
            auto str = std::string(#F ": ") + cudaGetErrorString(ret); \
            if (!std::uncaught_exceptions())                           \
                throw std::runtime_error{str};                         \
            else                                                       \
                std::cerr << "Error: " << str << std::endl;            \
        }                                                              \
    } while (0)

// Like CUDA_CHECK, but never throws: on failure it logs the stringified call
// and error text to stderr and calls std::terminate(). Intended for contexts
// that must not throw — destructors and deleters in this file use it.
#define CUDA_CHECK_FATAL(F)                                            \
    do                                                                 \
    {                                                                  \
        auto ret = F;                                                  \
        if (ret != cudaSuccess)                                        \
        {                                                              \
            auto str = std::string(#F ": ") + cudaGetErrorString(ret); \
            std::cerr << "Error: " << str << std::endl;                \
            std::terminate();                                          \
        }                                                              \
    } while (0)

namespace cuda {

// Minimal stand-in for C++20's std::is_unbounded_array: detects T[] (array
// of unknown bound). Used below to constrain the array-allocating overloads
// of make_device_ptr / make_pinned_host_ptr. Primary template: false.
template<typename T>
struct is_unbounded_array_t : std::false_type{};

// Partial specialization: only unbounded array types T[] map to true_type.
template<typename T>
struct is_unbounded_array_t<T[]> : std::true_type{};

// Convenience variable template mirroring the standard library's *_v naming.
template<typename T>
constexpr inline bool is_unbounded_array_v = is_unbounded_array_t<T>::value;

// unique_ptr deleter for device memory obtained from cudaMalloc.
// T may be an object type or an unbounded array type T[].
template<typename T>
struct DeviceDeleter
{
    // Device memory is freed without running destructors, so only trivially
    // destructible element types are allowed.
    static_assert(std::is_trivially_destructible<std::decay_t<T>>::value, "Type must have a trivial destructor");

    // Mirrors the delete-expression contract: null pointers are a no-op.
    // Terminates on cudaFree failure since deleters must not throw.
    void operator()(std::remove_extent_t<T>* ptr) noexcept {
        if (ptr == nullptr)
            return;
        CUDA_CHECK_FATAL(cudaFree(ptr));
    }
};

// unique_ptr deleter for heap objects whose memory was page-locked via
// cudaHostRegister: it first unregisters the range, then performs the
// ordinary delete via std::default_delete<T> (array-aware for T[]).
template<typename T>
struct PinnedHostDeleter
{
    // The pointer is ultimately deleted through std::default_delete, but the
    // allocation helpers in this file only support trivially destructible types.
    static_assert(std::is_trivially_destructible<std::decay_t<T>>::value, "Type must have a trivial destructor");

    // Null pointers are a no-op; unregister failure terminates (no-throw deleter).
    void operator()(std::remove_extent_t<T>* ptr) noexcept {
        if (ptr == nullptr)
            return;
        CUDA_CHECK_FATAL(cudaHostUnregister(ptr));
        std::default_delete<T>{}(ptr);
    }
};

// Owning pointer to device (cudaMalloc'd) memory; released with cudaFree.
template<typename T>
using DevicePtr = std::unique_ptr<T, DeviceDeleter<T>>;

// Owning pointer to heap memory that was page-locked with cudaHostRegister;
// unregistered and then deleted on destruction.
template<typename T>
using PinnedHostPtr = std::unique_ptr<T, PinnedHostDeleter<T>>;

// Allocates n elements of device memory and returns an owning DevicePtr.
// Only enabled for unbounded array types (e.g. make_device_ptr<float[]>(n)).
// Throws std::runtime_error (via CUDA_CHECK) if cudaMalloc fails.
template<typename T>
DevicePtr<std::enable_if_t<is_unbounded_array_v<T>, T>> make_device_ptr(size_t n)
{
    using Elem = std::remove_extent_t<T>;
    // Fix: initialize to nullptr — if CUDA_CHECK takes its log-only path
    // (failure during stack unwinding) the original wrapped an indeterminate
    // pointer in the DevicePtr, making the eventual cudaFree undefined.
    void* ptr = nullptr;
    CUDA_CHECK(cudaMalloc(&ptr, n * sizeof(Elem)));
    // static_cast is sufficient (and idiomatic) for void* -> object pointer.
    return DevicePtr<T>(static_cast<Elem*>(ptr));
}

// Creates a single T on the host heap and page-locks its memory with
// cudaHostRegister so it can be used for asynchronous CUDA copies.
// Throws (via CUDA_CHECK) on registration failure; in that case the
// still-owning unique_ptr frees the object normally — no leak.
template<typename T>
PinnedHostPtr<T> make_pinned_host_ptr()
{
    auto host_obj = std::make_unique<T>();
    CUDA_CHECK(cudaHostRegister(host_obj.get(), sizeof(T), 0));
    // Registration succeeded: hand ownership to the unregister-then-delete deleter.
    auto* raw = host_obj.release();
    return PinnedHostPtr<T>(reinterpret_cast<std::remove_extent_t<T>*>(raw));
}

// Array overload: allocates n elements on the host heap and page-locks the
// whole range. Only enabled for unbounded array types (e.g. <float[]>).
// Throws (via CUDA_CHECK) on registration failure, in which case the
// temporary unique_ptr still owns the array and frees it — no leak.
template<typename T>
PinnedHostPtr<std::enable_if_t<is_unbounded_array_v<T>, T>> make_pinned_host_ptr(size_t n)
{
    auto host_arr = std::make_unique<T>(n);
    const size_t bytes = n * sizeof(std::remove_extent_t<T>);
    CUDA_CHECK(cudaHostRegister(host_arr.get(), bytes, 0));
    auto* raw = host_arr.release();
    return PinnedHostPtr<T>(reinterpret_cast<std::remove_extent_t<T>*>(raw));
}

// Synchronous host -> device copy of `size` bytes.
// Fix: `inline` is required — this is a non-template function defined in a
// #pragma once header, so without it every including translation unit emits
// a duplicate definition (ODR violation / multiple-definition link error).
// get_device_props below already follows this convention.
inline void memcpy_to_device(void* dst, const void* src, size_t size)
{
    CUDA_CHECK(cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice));
}

// Synchronous device -> host copy of `size` bytes.
// Fix: `inline` added — non-template definition in a header; without it,
// inclusion from multiple translation units violates the ODR at link time.
inline void memcpy_to_host(void* dst, const void* src, size_t size)
{
    CUDA_CHECK(cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost));
}

// Typed convenience wrapper: synchronously copies `size` elements of T
// (not bytes) from host memory `src` to device memory `dst`.
template<typename T>
void copy_to_device(T* dst, const T* src, size_t size)
{
    const size_t bytes = size * sizeof(T);
    memcpy_to_device(dst, src, bytes);
}

// Typed convenience wrapper: synchronously copies `size` elements of T
// (not bytes) from device memory `src` to host memory `dst`.
template<typename T>
void copy_to_host(T* dst, const T* src, size_t size)
{
    const size_t bytes = size * sizeof(T);
    memcpy_to_host(dst, src, bytes);
}

// RAII wrapper around a cudaStream_t: creates the stream on construction,
// destroys it on destruction. Non-copyable — a copy would double-destroy
// the underlying stream.
class CudaStream {
public:
    CudaStream() {
        CUDA_CHECK(cudaStreamCreate(&stream_));
    }

    CudaStream(const CudaStream&) = delete;
    // Fix: the original deleted only the copy constructor; the copy
    // assignment operator was still implicitly generated (deprecated but
    // legal), so `a = b;` compiled and led to a double cudaStreamDestroy.
    CudaStream& operator=(const CudaStream&) = delete;

    // FATAL (terminate) rather than throw: destructors must not throw.
    ~CudaStream() {
        CUDA_CHECK_FATAL(cudaStreamDestroy(stream_));
    }

    // Implicit conversion so the wrapper can be passed directly to CUDA APIs.
    operator cudaStream_t() const { return stream_; }

    // Blocks the host until all work queued on this stream has completed.
    void sync() { CUDA_CHECK(cudaStreamSynchronize(stream_)); }

private:
    cudaStream_t stream_;
};

class CudaEvent {
public:
    CudaEvent() {
        CUDA_CHECK(cudaEventCreate(&event_));
    }

    void record(cudaStream_t stream) {
        CUDA_CHECK(cudaEventRecord(event_, stream));
    }

    void sync() {
        CUDA_CHECK(cudaEventSynchronize(event_));
    }

    float elapsed_time_ms(cudaEvent_t end) const {
        float ms;
        CUDA_CHECK(cudaEventElapsedTime(&ms, event_, end));
        return ms;
    }

    CudaEvent(const CudaEvent&) = delete;
    operator cudaEvent_t() const { return event_; }

private:
    cudaEvent_t event_;
};

// Asynchronous host -> device copy of `size` bytes on `stream`.
// The source should be pinned host memory for the copy to be truly async.
// Fix: `inline` added — non-template definition in a #pragma once header;
// without it, inclusion from multiple TUs violates the ODR at link time.
inline void memcpy_to_device(void* dst, const void* src, size_t size, cudaStream_t stream)
{
    CUDA_CHECK(cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, stream));
}

// Asynchronous device -> host copy of `size` bytes on `stream`.
// The destination should be pinned host memory for the copy to be truly async.
// Fix: `inline` added — non-template definition in a #pragma once header;
// without it, inclusion from multiple TUs violates the ODR at link time.
inline void memcpy_to_host(void* dst, const void* src, size_t size, cudaStream_t stream)
{
    CUDA_CHECK(cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, stream));
}

// Typed convenience wrapper: asynchronously copies `size` elements of T
// (not bytes) from host memory `src` to device memory `dst` on `stream`.
template<typename T>
void copy_to_device(T* dst, const T* src, size_t size, cudaStream_t stream)
{
    const size_t bytes = size * sizeof(T);
    memcpy_to_device(dst, src, bytes, stream);
}

// Typed convenience wrapper: asynchronously copies `size` elements of T
// (not bytes) from device memory `src` to host memory `dst` on `stream`.
template<typename T>
void copy_to_host(T* dst, const T* src, size_t size, cudaStream_t stream)
{
    const size_t bytes = size * sizeof(T);
    memcpy_to_host(dst, src, bytes, stream);
}

// Fixed-size, move-only buffer of `size` elements of T living in device
// memory. The element count is set once at construction.
template<typename T>
class DeviceBuffer {
public:
    // Allocates room for `size` elements on the device (throws via
    // CUDA_CHECK inside make_device_ptr on allocation failure).
    explicit DeviceBuffer(size_t size)
        : data_{make_device_ptr<T[]>(size)}
        , size_{size}
    {}

    // Raw device pointer — pass to kernels/CUDA APIs; do not dereference
    // on the host. No operator[] is provided for the same reason.
    T* data() { return data_.get(); }
    const T* data() const { return data_.get(); }

    // Number of elements (not bytes).
    size_t size() const { return size_; }

private:
    DevicePtr<T[]> data_;
    const size_t size_;
};

// Fixed-size, move-only buffer of `size` elements of T in page-locked
// (pinned) host memory, suitable for asynchronous CUDA copies.
template<typename T>
class PinnedHostBuffer {
public:
    // Allocates and pins room for `size` elements (throws via CUDA_CHECK
    // inside make_pinned_host_ptr on failure).
    explicit PinnedHostBuffer(size_t size)
        : data_{make_pinned_host_ptr<T[]>(size)}
        , size_{size}
    {}

    // Raw host pointer to the pinned allocation.
    T* data() { return data_.get(); }
    const T* data() const { return data_.get(); }

    // Number of elements (not bytes).
    size_t size() const { return size_; }

    // Unchecked element access — host memory, safe to dereference here.
    T& operator[](size_t i) { return data_[i]; }
    const T& operator[](size_t i) const { return data_[i]; }

private:
    PinnedHostPtr<T[]> data_;
    const size_t size_;
};

// Matched pair of buffers — pinned host memory plus device memory of the
// same element count — with helpers to shuttle the contents between them.
template<typename T>
class DualBuffer {
public:
    // Allocates both sides for `size` elements (throws on failure).
    explicit DualBuffer(size_t size)
    : host_{make_pinned_host_ptr<T[]>(size)}
    , device_{make_device_ptr<T[]>(size)}
    , size_{size}
    {}

    T* host() { return host_.get(); }
    const T* host() const { return host_.get(); }
    T* device() { return device_.get(); }
    const T* device() const { return device_.get(); }

    // Element count shared by both buffers (not bytes).
    size_t size() const { return size_; }

    // Synchronous host -> device copy of the whole buffer.
    // Fix: the free functions must be ::cuda:: qualified — unqualified
    // lookup stops at class scope and finds only the 0/1-argument member
    // overloads (ADL is suppressed once a member is found), so the original
    // 3-argument calls failed to compile at instantiation. The stream
    // overloads below already used the qualified form.
    void copy_to_device() {
        ::cuda::copy_to_device(device(), host(), size());
    }

    // Synchronous device -> host copy of the whole buffer.
    void copy_to_host() {
        ::cuda::copy_to_host(host(), device(), size());
    }

    // Asynchronous host -> device copy on `stream` (the host side is
    // pinned, so cudaMemcpyAsync can be truly asynchronous).
    void copy_to_device(cudaStream_t stream) {
        ::cuda::copy_to_device(device(), host(), size(), stream);
    }

    // Asynchronous device -> host copy on `stream`.
    void copy_to_host(cudaStream_t stream) {
        ::cuda::copy_to_host(host(), device(), size(), stream);
    }

private:
    PinnedHostPtr<T[]> host_;
    DevicePtr<T[]> device_;
    const size_t size_;
};

inline std::vector<cudaDeviceProp> get_device_props()
{
    std::vector<cudaDeviceProp> device_props;
    int device_count;
    CUDA_CHECK(cudaGetDeviceCount(&device_count));
    for(auto i = 0; i < device_count; ++i) {
        device_props.emplace_back();
        CUDA_CHECK(cudaGetDeviceProperties(&device_props.back(), i));
    }

    return device_props;
}

} // namespace cuda